// Copyright © 2023 Cisco Systems, Inc. and its affiliates.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package exploitdb

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/openclarity/openclarity/scanner/families/exploits/types"

	"github.com/cenkalti/backoff"
	log "github.com/sirupsen/logrus"
	exploitmodels "github.com/vulsio/go-exploitdb/models"
)

// Timeouts used when querying the exploit DB server over HTTP.
const (
	// globalTimeoutMin is the overall timeout, in minutes, applied by
	// getExploitsViaHTTP to a batch of CVE lookups.
	globalTimeoutMin = 10
	// taskTimeoutSec is the timeout, in seconds, of a single HTTP attempt
	// made by httpGetExploit.
	taskTimeoutSec = 30
)

type (
	// exploitResponse pairs an exploitRequest with the raw JSON body the
	// exploit DB server returned for it.
	exploitResponse struct {
		request exploitRequest
		json    string
	}

	// exploitRequest identifies the CVE an exploit lookup was made for.
	exploitRequest struct {
		cveID string
	}
)

// getExploitsViaHTTP queries the exploit DB HTTP API once per CVE ID, at
// urlPrefix/<cveID>, and returns the raw JSON body of every response.
//
// Lookups run on a pool of concurrent workers (see genWorkers); retries and
// the per-attempt timeout are handled by httpGetExploit. The whole batch,
// including time spent waiting for a free worker, is bounded by
// globalTimeoutMin minutes and by ctx.
//
// Responses are returned in completion order, not input order. If any lookup
// fails, it returns nil and an error wrapping every failure. If the deadline
// expires or ctx is cancelled first, it returns nil and a
// "timeout fetching exploits" error wrapping ctx.Err().
func getExploitsViaHTTP(ctx context.Context, cveIDs []string, urlPrefix string) ([]exploitResponse, error) {
	numCVEs := len(cveIDs)
	if numCVEs == 0 {
		return nil, nil
	}

	// Bound the whole batch, and make sure the dispatcher and in-flight
	// requests started below are told to stop once this function returns.
	ctx, cancel := context.WithTimeout(ctx, globalTimeoutMin*time.Minute)
	defer cancel()

	// Each task sends exactly one outcome, so a buffer of numCVEs means a
	// worker never blocks on send, even after an early return. The channels
	// are deliberately never closed: a late send on a closed channel panics.
	resChan := make(chan exploitResponse, numCVEs)
	errChan := make(chan error, numCVEs)

	const concurrency = 10
	tasks := genWorkers(concurrency)

	// Dispatch from a separate goroutine so the deadline also covers waiting
	// for free workers, and results are collected while dispatch proceeds.
	go func() {
		// Closing the queue lets the worker goroutines exit once it drains.
		defer close(tasks)
		for _, cveID := range cveIDs {
			task := func() {
				req := exploitRequest{cveID: cveID}
				URL, err := url.JoinPath(urlPrefix, cveID)
				if err != nil {
					errChan <- err
					return
				}
				log.Debugf("HTTP Request to %s", URL)
				httpGetExploit(ctx, URL, req, resChan, errChan)
			}
			select {
			case tasks <- task:
			case <-ctx.Done():
				return
			}
		}
	}()

	responses := make([]exploitResponse, 0, numCVEs)
	var errs []error
	for range numCVEs {
		select {
		case res := <-resChan:
			responses = append(responses, res)
		case err := <-errChan:
			errs = append(errs, err)
		case <-ctx.Done():
			return nil, fmt.Errorf("timeout fetching exploits: %w", ctx.Err())
		}
	}
	if len(errs) != 0 {
		return nil, fmt.Errorf("failed to fetch exploit: %w", errors.Join(errs...))
	}

	return responses, nil
}

// httpGetExploit performs an HTTP GET against url and reports exactly one
// outcome. On success it sends the response body, paired with req, on
// resChan. On failure it sends the final error on errChan.
//
// Failed attempts are retried with exponential backoff, up to maxRetries
// retries. Each attempt is bounded by taskTimeoutSec seconds and by ctx.
// Any status other than 200 OK counts as a failed attempt.
func httpGetExploit(ctx context.Context, url string, req exploitRequest, resChan chan<- exploitResponse, errChan chan<- error) {
	const maxRetries uint64 = 3
	var body string

	requestFn := func() error {
		attemptCtx, cancel := context.WithTimeout(ctx, taskTimeoutSec*time.Second)
		defer cancel()

		r, err := http.NewRequestWithContext(attemptCtx, http.MethodGet, url, nil)
		if err != nil {
			return fmt.Errorf("failed to create request: %w", err)
		}

		resp, err := http.DefaultClient.Do(r)
		if err != nil {
			return fmt.Errorf("failed to send GET request to url: %s: %w", url, err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			// err is nil here, so there is nothing to wrap; report the status.
			return fmt.Errorf("received error response to GET request to url: %s, status: %s", url, resp.Status)
		}

		b, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed to read response body: %w", err)
		}

		body = string(b)

		return nil
	}

	notifyFn := func(err error, wait time.Duration) {
		// A time.Duration formats with its own unit (e.g. "1.5s").
		log.Warnf("Failed to HTTP GET. retrying in %s: %+v", wait, err)
	}

	err := backoff.RetryNotify(requestFn, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), maxRetries), notifyFn)
	if err != nil {
		errChan <- fmt.Errorf("HTTP Error %w", err)
		return
	}

	resChan <- exploitResponse{
		request: req,
		json:    body,
	}
}

// convertToCommonExploits maps the go-exploitdb exploits found for a single
// CVE onto the openclarity types.Exploit model, one entry per input exploit,
// all tagged with cveID. It returns nil when es is empty.
func convertToCommonExploits(es []exploitmodels.Exploit, cveID string) []types.Exploit {
	if len(es) == 0 {
		return nil
	}

	converted := make([]types.Exploit, 0, len(es))
	for i := range es {
		src := &es[i]
		converted = append(converted, types.Exploit{
			// Exploit DB calculates a unique ID like this:
			// https://github.com/vulsio/go-exploitdb/blob/master/fetcher/inthewild.go#L110
			// which takes into account the URL, CVEID and some
			// other fields which are specific to the SourceDB
			// type. For now we use the same thing, but we may
			// want to standardise how that ID is generated across
			// all our Exploit scanners so that if multiple
			// scanners find the same exploit from the same source
			// DB we can match them together and de-duplicate.
			ID:          src.ExploitUniqueID,
			Name:        "",
			Title:       "",
			Description: src.Description,
			CveID:       cveID,
			URLs:        []string{src.URL},
			SourceDB:    string(src.ExploitType),
		})
	}
	return converted
}

// genWorkers generates goroutine
// http://qiita.com/na-o-ys/items/65373132b1c5bc973cca
// code taken from https://github.com/future-architect/vuls/blob/bfe0db77b4e16e3099a1e58b8db8f18120a11117/util/util.go#L16
func genWorkers(num int) chan<- func() {
	tasks := make(chan func())
	for range num {
		go func() {
			defer func() {
				if p := recover(); p != nil {
					log.Errorf("run time panic: %+v", p)
				}
			}()
			for f := range tasks {
				f()
			}
		}()
	}
	return tasks
}

// stringToArray converts a comma-separated string into a slice of its
// elements, e.g. "str1, str2,str3" => [str1 str2 str3].
//
// All whitespace, including whitespace inside an element, is removed from
// each element. Elements that end up empty are dropped. As a result, "",
// whitespace-only input and ",," all yield nil, and "a,,b," yields [a b].
func stringToArray(str string) []string {
	var elems []string
	for _, part := range strings.Split(str, ",") {
		// strings.Fields splits on any Unicode whitespace; joining the
		// fields back together strips all of it.
		if elem := strings.Join(strings.Fields(part), ""); elem != "" {
			elems = append(elems, elem)
		}
	}
	return elems
}
